# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from torch import Tensor
from torch import nn as nn

from mmdet3d.registry import MODELS
from mmdet3d.structures import points_cam2img
from . import apply_3d_transformation, bbox_2d_transform, coord_2d_transform

# Small constant added to norms / denominators to avoid division by zero.
EPS = 1e-6


@MODELS.register_module()
class VoteFusion(nn.Module):
    """Fuse 2d features from 3d seeds.

    For every 3D seed the module projects the seed onto the image, pairs it
    with the 2D boxes it falls into and builds, per (seed, box) pair:

    - 5 geometric cues (2 values of the lifted image vote + 3 values of the
      normalized ray direction),
    - ``num_classes`` semantic cues (box confidence placed at the box class),
    - 3 texture cues (normalized pixel value at the projected location).

    Args:
        num_classes (int): Number of classes.
        max_imvote_per_pixel (int): Max number of imvotes.
    """

    def __init__(self,
                 num_classes: int = 10,
                 max_imvote_per_pixel: int = 3) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.max_imvote_per_pixel = max_imvote_per_pixel

    def forward(self, imgs: List[Tensor], bboxes_2d_rescaled: List[Tensor],
                seeds_3d_depth: List[Tensor],
                img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            imgs (List[Tensor]): Images, one per sample. Each one is
                reshaped to ``(3, -1)`` after removing the padding, so it
                is expected to have 3 channels in its first dimension.
            bboxes_2d_rescaled (List[Tensor]): 2D bboxes, one tensor per
                sample whose last dimension splits into
                ``(left, top, right, bottom, confidence, class)``.
            seeds_3d_depth (List[Tensor]): 3D seeds, one ``(seed_num, 3)``
                tensor per sample.
            img_metas (List[dict]): Meta information of images. The keys
                ``'img_shape'`` and ``'depth2img'`` are read here; the
                transformation helpers read further entries.

        Returns:
            Tuple[Tensor, Tensor]:

                - img_features: Concatenated cues of each point, of shape
                  ``(batch, 5 + num_classes + 3,
                  seed_num * max_imvote_per_pixel)``.
                - masks: Validity mask of each feature, of shape
                  ``(batch, seed_num * max_imvote_per_pixel)``.
        """
        img_features = []
        masks = []
        for img, bbox_2d_rescaled, seed_3d_depth, img_meta in zip(
                imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas):
            bbox_num = bbox_2d_rescaled.shape[0]
            seed_num = seed_3d_depth.shape[0]

            img_shape = img_meta['img_shape']
            # first reverse the data transformations
            xyz_depth = apply_3d_transformation(
                seed_3d_depth, 'DEPTH', img_meta, reverse=True)

            # project points from depth to image
            depth2img = xyz_depth.new_tensor(img_meta['depth2img'])
            uvz_origin = points_cam2img(xyz_depth, depth2img, True)
            z_cam = uvz_origin[..., 2]
            uv_origin = (uvz_origin[..., :2] - 1).round()

            # rescale 2d coordinates and bboxes
            uv_rescaled = coord_2d_transform(img_meta, uv_origin, True)
            bbox_2d_origin = bbox_2d_transform(img_meta, bbox_2d_rescaled,
                                               False)

            if bbox_num == 0:
                imvote_num = seed_num * self.max_imvote_per_pixel

                # use zero features; the row count must match the
                # geometric (5) + semantic (num_classes) cues produced by
                # the non-empty branch so that samples can be stacked
                two_cues = torch.zeros((5 + self.num_classes, imvote_num),
                                       device=seed_3d_depth.device)
                # only the first vote slot of every seed is flagged valid
                mask_zero = torch.zeros(
                    imvote_num - seed_num, device=seed_3d_depth.device).bool()
                mask_one = torch.ones(
                    seed_num, device=seed_3d_depth.device).bool()
                mask = torch.cat([mask_one, mask_zero], dim=0)
            else:
                # expand bboxes and seeds to (seed_num, bbox_num, C) so that
                # every seed is paired with every bbox
                bbox_expanded = bbox_2d_origin.view(1, bbox_num, -1).expand(
                    seed_num, -1, -1)
                seed_2d_expanded = uv_origin.view(seed_num, 1,
                                                  -1).expand(-1, bbox_num, -1)
                seed_2d_expanded_x, seed_2d_expanded_y = \
                    seed_2d_expanded.split(1, dim=-1)

                bbox_expanded_l, bbox_expanded_t, bbox_expanded_r, \
                    bbox_expanded_b, bbox_expanded_conf, bbox_expanded_cls = \
                    bbox_expanded.split(1, dim=-1)
                bbox_expanded_midx = (bbox_expanded_l + bbox_expanded_r) / 2
                bbox_expanded_midy = (bbox_expanded_t + bbox_expanded_b) / 2

                # a seed belongs to a bbox if it lies strictly inside it
                seed_2d_in_bbox_x = (seed_2d_expanded_x > bbox_expanded_l) * \
                    (seed_2d_expanded_x < bbox_expanded_r)
                seed_2d_in_bbox_y = (seed_2d_expanded_y > bbox_expanded_t) * \
                    (seed_2d_expanded_y < bbox_expanded_b)
                seed_2d_in_bbox = seed_2d_in_bbox_x * seed_2d_in_bbox_y

                # semantic cues, dim=class_num: the bbox confidence is
                # written at the index of the bbox class, zeros elsewhere
                sem_cue = torch.zeros_like(bbox_expanded_conf).expand(
                    -1, -1, self.num_classes)
                sem_cue = sem_cue.scatter(-1, bbox_expanded_cls.long(),
                                          bbox_expanded_conf)

                # bbox center - uv
                delta_u = bbox_expanded_midx - seed_2d_expanded_x
                delta_v = bbox_expanded_midy - seed_2d_expanded_y

                seed_3d_expanded = seed_3d_depth.view(seed_num, 1, -1).expand(
                    -1, bbox_num, -1)

                # scale the pixel offset by the seed depth and map it back
                # through the inverse of the projection matrix
                z_cam = z_cam.view(seed_num, 1, 1).expand(-1, bbox_num, -1)
                imvote = torch.cat(
                    [delta_u, delta_v,
                     torch.zeros_like(delta_v)], dim=-1).view(-1, 3)
                imvote = imvote * z_cam.reshape(-1, 1)
                imvote = imvote @ torch.inverse(depth2img.t())

                # apply transformation to lifted imvotes
                imvote = apply_3d_transformation(
                    imvote, 'DEPTH', img_meta, reverse=False)

                seed_3d_expanded = seed_3d_expanded.reshape(imvote.shape)

                # ray angle: normalized direction of (seed + imvote)
                ray_angle = seed_3d_expanded + imvote
                ray_angle /= torch.sqrt(torch.sum(ray_angle**2, -1) +
                                        EPS).unsqueeze(-1)

                # imvote lifted to 3d
                xz = ray_angle[:, [0, 2]] / (ray_angle[:, [1]] + EPS) \
                    * seed_3d_expanded[:, [1]] - seed_3d_expanded[:, [0, 2]]

                # geometric cues, dim=5
                geo_cue = torch.cat([xz, ray_angle],
                                    dim=-1).view(seed_num, -1, 5)

                two_cues = torch.cat([geo_cue, sem_cue], dim=-1)
                # mask to 0 if seed not in bbox
                two_cues = two_cues * seed_2d_in_bbox.float()

                feature_size = two_cues.shape[-1]
                # if bbox number is too small, append zeros
                if bbox_num < self.max_imvote_per_pixel:
                    append_num = self.max_imvote_per_pixel - bbox_num
                    append_zeros = torch.zeros(
                        (seed_num, append_num, 1),
                        device=seed_2d_in_bbox.device).bool()
                    seed_2d_in_bbox = torch.cat(
                        [seed_2d_in_bbox, append_zeros], dim=1)
                    append_zeros = torch.zeros(
                        (seed_num, append_num, feature_size),
                        device=two_cues.device)
                    two_cues = torch.cat([two_cues, append_zeros], dim=1)
                    append_zeros = torch.zeros((seed_num, append_num, 1),
                                               device=two_cues.device)
                    bbox_expanded_conf = torch.cat(
                        [bbox_expanded_conf, append_zeros], dim=1)

                # sort the valid seed-bbox pair according to confidence
                pair_score = seed_2d_in_bbox.float() + bbox_expanded_conf
                # and find the largests
                _, indices = pair_score.topk(
                    self.max_imvote_per_pixel,
                    dim=1,
                    largest=True,
                    sorted=True)

                indices_img = indices.expand(-1, -1, feature_size)
                two_cues = two_cues.gather(dim=1, index=indices_img)
                # reorder to (feature_size, max_imvote_per_pixel * seed_num)
                # with all seeds of the first vote slot coming first
                two_cues = two_cues.transpose(1, 0)
                two_cues = two_cues.reshape(-1, feature_size).transpose(
                    1, 0).contiguous()

                # read the validity straight from the in-bbox flags of the
                # selected pairs; flooring the score instead would mark an
                # out-of-box pair with confidence exactly 1.0 as valid
                mask = seed_2d_in_bbox.gather(dim=1, index=indices)
                mask = mask.transpose(1, 0).reshape(-1).bool()

            # clear the padding
            img = img[:, :img_shape[0], :img_shape[1]]
            # out-of-place division: reshape()/float() may return a view of
            # the caller's image, which must not be rescaled in place
            img_flatten = img.reshape(3, -1).float() / 255.

            # take the normalized pixel value as texture cue; pixel
            # coordinates are rounded and clamped into the image
            u_idx = torch.clamp(uv_rescaled[:, 0].round(), 0,
                                img_shape[1] - 1)
            v_idx = torch.clamp(uv_rescaled[:, 1].round(), 0,
                                img_shape[0] - 1)
            # row-major flat index into the flattened image
            uv_flatten = v_idx * img_shape[1] + u_idx
            uv_expanded = uv_flatten.unsqueeze(0).expand(3, -1).long()
            txt_cue = torch.gather(img_flatten, dim=-1, index=uv_expanded)
            # repeat the texture cue for every vote slot of a seed
            txt_cue = txt_cue.unsqueeze(1).expand(-1,
                                                  self.max_imvote_per_pixel,
                                                  -1).reshape(3, -1)

            # append texture cue
            img_feature = torch.cat([two_cues, txt_cue], dim=0)
            img_features.append(img_feature)
            masks.append(mask)

        return torch.stack(img_features, 0), torch.stack(masks, 0)
